FedProto: Federated Prototype Learning across Heterogeneous Clients

Yue Tan1, Guodong Long1, Lu Liu1, Tianyi Zhou2,3, Qinghua Lu4, Jing Jiang1, Chengqi Zhang1
Abstract

Heterogeneity across clients in federated learning (FL) usually hinders optimization convergence and generalization performance when the aggregation of clients' knowledge occurs in the gradient space. For example, clients may differ in terms of data distribution, network latency, input/output space, and/or model architecture, which can easily lead to the misalignment of their local gradients. To improve tolerance to heterogeneity, we propose a novel federated prototype learning (FedProto) framework in which the clients and server communicate abstract class prototypes instead of gradients. FedProto aggregates the local prototypes collected from different clients, and then sends the global prototypes back to all clients to regularize the training of local models. The training on each client aims to minimize the classification error on the local data while keeping the resulting local prototypes sufficiently close to the corresponding global ones. Moreover, we provide a theoretical analysis of the convergence rate of FedProto under non-convex objectives. In experiments, we propose a benchmark setting tailored for heterogeneous FL, in which FedProto outperforms several recent FL approaches on multiple datasets.

Introduction

Federated learning (FL) is widely used in multiple applications to enable collaborative learning across a variety of clients without sharing private data. It aims at training a global model on a centralized server while all data are distributed over many local clients and cannot be freely transmitted due to privacy or communication concerns (McMahan et al. 2017). The iterative process of FL has two steps: (1) each local client is synchronized with the global model and then trained on its local data; and (2) the server updates the global model by aggregating all the local models. Because the model aggregation occurs in the gradient space, traditional FL still faces practical challenges caused by the heterogeneity of data and models (Kairouz et al. 2019). Efficient algorithms that overcome both of these challenges have not yet been fully developed or systematically examined.

To tackle the statistical heterogeneity of data distributions, one straightforward solution is to maintain multiple global models for different local distributions, e.g., clustered FL (Sattler, Müller, and Samek 2020). Another widely studied strategy is personalized FL (Tan et al. 2021), where a personalized model is generated for each client by leveraging both global and local information. Nevertheless, most of these methods depend on gradient-based aggregation, resulting in high communication costs and heavy reliance on homogeneous local models.

However, in real-world applications, model heterogeneity is common because of varying hardware and computation capabilities across clients (Long et al. 2020). Knowledge Distillation (KD)-based FL (Lin et al. 2020) addresses this challenge by transferring the teacher model’s knowledge to student models with different model architectures. However, these methods require an extra public dataset to align the student and teacher models’ outputs, increasing the computation costs. Moreover, the performance of KD-based FL can significantly degrade with the increase in the distribution divergence between the public dataset and on-client datasets that are usually non-IID.

Inspired by prototype learning, merging prototypes over heterogeneous datasets can effectively integrate feature representations from diverse data distributions (Snell, Swersky, and Zemel 2017; Liu et al. 2020; Dvornik, Schmid, and Mairal 2020). On-client intelligent agents in an FL system can share knowledge by exchanging information in terms of representations, despite statistical and model heterogeneity (Cai et al. 2020; Li et al. 2021). For example, when we talk about "dog", different people have a unique "imagined picture", or "prototype", representing the concept "dog". Their prototypes may differ slightly due to different life experiences and visual memories. Exchanging these concept-specific prototypes enables people to acquire more knowledge about the concept "dog". Treating each FL client as a human-like intelligent agent, the core idea of our method is to exchange prototypes rather than model parameters or raw data, which naturally matches the knowledge-acquisition behavior of humans.

In this paper, we propose a novel prototype aggregation-based FL framework where only prototypes are transmitted between the server and clients. The proposed solution does not require model parameters or gradients to be aggregated, so it has a huge potential to be a robust framework for various heterogeneous FL scenarios. Concretely, each client can have different model architectures and input/output spaces, but they can still exchange information by sharing prototypes. Each abstract prototype represents a class by the mean representations transformed from the observed samples belonging to the same class. Aggregating the prototypes allows for efficient communication across heterogeneous clients.

Our main contributions can be summarized as follows:

  • We propose a benchmark setting tailored for heterogeneous FL that considers a more general heterogeneous scenario across local clients.

  • We present a novel FL method that significantly improves the communication efficiency in the heterogeneous setting. To the best of our knowledge, we are the first to propose prototype aggregation-based FL.

  • We theoretically provide a convergence guarantee for our method and carefully derive the convergence rate under non-convex conditions.

  • Extensive experiments show the superiority of our proposed method in terms of communication efficiency and test performance in several benchmark datasets.

Related Work

Heterogeneous Federated Learning

Statistical heterogeneity across clients (also known as the non-IID problem) is the most important challenge in FL. FedProx (Li et al. 2020) introduces a local regularization term to optimize each client's local model. Some recent studies (Arivazhagan et al. 2019; Liang et al. 2020; Deng, Kamani, and Mahdavi 2020) train personalized models to leverage both globally shared information and a personalized part (Tan et al. 2021; Jiang, Ji, and Long 2020). A third solution is to provide multiple global models by clustering the local models into multiple groups (Mansour et al. 2020; Ghosh et al. 2020; Sattler, Müller, and Samek 2020). Recently, self-supervised learning strategies have been incorporated into the local training phase to handle the heterogeneity challenges (Li, He, and Song 2021; Liu et al. 2021a; Yang et al. 2021), and (Fallah, Mokhtari, and Ozdaglar 2020) applies a meta-training strategy to personalized FL.

Heterogeneous model architectures are another major challenge in FL. The recently proposed KD-based FL (Lin et al. 2020; Jeong et al. 2018; Li and Wang 2020; Long et al. 2021) can serve as an alternative solution to this challenge. In particular, under the assumption that a small shared dataset is available in the federated setting, these KD-based FL methods can distill knowledge from a teacher model to student models with different model architectures. Some recent studies have also attempted to combine neural architecture search with federated learning (Zhu, Zhang, and Jin 2020; He, Annavaram, and Avestimehr 2020; Singh et al. 2020), which can discover a customized model architecture for each group of clients with different hardware capabilities and configurations. A collective learning platform is proposed in (Hoang et al. 2019) to handle heterogeneous architectures without access to the local training data and architectures. Moreover, functionality-based neural matching across local models (Wang et al. 2020a) can aggregate neurons with similar functionality regardless of variance in the model architectures.

However, most of the aforementioned FL methods focus on only one of these heterogeneity challenges. Moreover, all of them use gradient-based aggregation, which raises concerns about communication efficiency and gradient-based attacks (Zhu, Liu, and Han 2019; Chen et al. 2020; Liu et al. 2021b; Zheng et al. 2021).

Prototype Learning

The concept of prototypes (the mean of multiple features) has been explored in a variety of tasks. In image classification, a prototype can be a proxy of a class and is calculated as the mean of the feature vectors within every class (Snell, Swersky, and Zemel 2017). In action recognition, the features of a video in different timestamps can be averaged to serve as the representation of the video (Simonyan and Zisserman 2014; Xue et al. 2021). Aggregated local features can serve as descriptors for image retrieval (Babenko and Lempitsky 2015). Averaging word embeddings as the representation of a sentence can achieve competitive performance on multiple NLP benchmarks (Wieting et al. 2015). The authors in (Hoang et al. 2020) use prototypes to represent task-agnostic information in distributed machine learning and propose a new fusion paradigm to integrate those prototypes to generate a new model for a new task. In (Michieli and Ozay 2021), prototype margins are used to optimize visual feature representations for FL. In our paper, we borrow the concept of prototypes to represent one class and apply prototype aggregation in the setting of heterogeneous FL.

In general, prototypes are widely used in learning scenarios with a limited number of training samples (Snell, Swersky, and Zemel 2017). This learning scenario is consistent with the latent assumption of cross-client FL: that each client has a limited number of instances to independently train a model with the desired performance. The assumption has been widely supported by the FL-based benchmark datasets (Caldas et al. 2018; He et al. 2020) and in related applications, such as healthcare (Rieke et al. 2020; Xu et al. 2020) and street image object detection (Luo et al. 2019).

Figure 1: An overview of FedProto in the heterogeneous setting. For example, the first client is to recognize the digits $\{2,3,4\}$, while the $m$-th client is to recognize the digits $\{4,5\}$. First, the clients update their local prototype sets by minimizing the classification loss $\mathcal{L}_S$ and the distance $\mathcal{L}_R$ between the global and local prototypes. Then, the clients send their prototypes to the central server. The central server generates global prototypes and returns them to all clients to regularize the training of local models.

Problem Setting

Heterogeneous Federated Learning Setting

In federated learning, each client owns a local private dataset $D_i$ drawn from distribution $\mathbb{P}_i(x,y)$, where $x$ and $y$ denote the input features and corresponding class labels, respectively. Usually, clients share a model $\mathcal{F}(\omega;x)$ with the same architecture and hyperparameters, parameterized by learnable weights $\omega$ and taking input features $x$. The objective function of FedAvg (McMahan et al. 2017) is:

\operatorname*{arg\,min}_{\omega}\sum_{i=1}^{m}\frac{|D_i|}{N}\mathcal{L}_S(\mathcal{F}(\omega;x),y), (1)

where $\omega$ is the global model's parameters, $m$ denotes the number of clients, $N$ is the total number of instances over all clients, $\mathcal{F}$ is the shared model, and $\mathcal{L}_S$ is the supervised loss of the learning task (e.g., a cross-entropy loss).

However, in a real-world FL environment, each client may be a mobile phone with a specific user behavior pattern or a sensor deployed in a particular location, leading to statistical and/or model heterogeneity. In the statistical heterogeneity setting, $\mathbb{P}_i$ varies across clients, indicating heterogeneous input/output spaces for $x$ and $y$. For example, $\mathbb{P}_i$ on different clients can be data distributions over different subsets of classes. In the model heterogeneity setting, $\mathcal{F}_i$ varies across clients, indicating different model architectures and hyperparameters. For the $i$-th client, the training procedure minimizes the loss defined below:

\operatorname*{arg\,min}_{\omega_1,\omega_2,\dots,\omega_m}\sum_{i=1}^{m}\frac{|D_i|}{N}\mathcal{L}_S(\mathcal{F}_i(\omega_i;x),y). (2)

Most existing methods cannot handle the heterogeneous settings above well. In particular, when $\mathcal{F}_i$ has a different model architecture, $\omega_i$ has a different format and size, so the global model's parameters $\omega$ cannot be optimized by averaging the $\omega_i$. To tackle this challenge, we propose to communicate and aggregate prototypes in FL.

Prototype-Based Aggregation Setting

Heterogeneous FL focuses on robustness to heterogeneous input/output spaces, distributions, and model architectures. For example, the datasets $D_i$ and $D_k$ on two clients $i$ and $k$ may have different statistical distributions of labels. This is common for a photo classification app installed on mobile clients, where the server needs to recognize many classes $\mathbb{C}=\{C^{(1)},C^{(2)},\dots\}$, while each client only needs to recognize a few classes that constitute a subset of $\mathbb{C}$. The class set can vary across clients, though there are overlaps.

In general, deep learning models comprise two parts: (1) representation layers (a.k.a. the embedding function) that transform the input from the original feature space to the embedding space; and (2) decision layers that make the classification decision for a given learning task.

Representation layers

The embedding function of the $i$-th client is $f_i(\phi_i)$, parameterized by $\phi_i$. We denote $h_i=f_i(\phi_i;x)$ as the embedding of $x$.

Decision layers

Given a supervised learning task, a prediction for $x$ can be generated by the function $g_i(\nu_i)$, parameterized by $\nu_i$. So, the labelling function can be written as $\mathcal{F}_i(\phi_i,\nu_i)=g_i(\nu_i)\circ f_i(\phi_i)$, and we use $\omega_i$ to represent $(\phi_i,\nu_i)$ for short.

Prototype

We define a prototype $C^{(j)}$ to represent the $j$-th class in $\mathbb{C}$. For the $i$-th client, the prototype is the mean of the embedding vectors of the instances in class $j$,

C_i^{(j)}=\frac{1}{|D_{i,j}|}\sum_{(x,y)\in D_{i,j}}f_i(\phi_i;x), (3)

where $D_{i,j}$, a subset of the local dataset $D_i$, comprises the training instances belonging to the $j$-th class.
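As a concrete illustration, the per-class mean of Eq. (3) can be sketched in plain Python; `local_prototypes` and `embed` are hypothetical names introduced here, with `embed` standing in for the client's embedding function $f_i(\phi_i;\cdot)$:

```python
def local_prototypes(samples, embed):
    """Compute one prototype per class as the mean embedding (Eq. 3).

    samples: iterable of (x, y) pairs drawn from the local dataset D_i.
    embed:   stand-in for f_i(phi_i; .), mapping an input x to a
             fixed-length list of floats (the embedding h_i).
    """
    sums, counts = {}, {}
    for x, y in samples:
        h = embed(x)
        if y not in sums:
            sums[y] = [0.0] * len(h)
            counts[y] = 0
        sums[y] = [s + v for s, v in zip(sums[y], h)]
        counts[y] += 1
    # divide each accumulated sum by |D_{i,j}| to obtain the class mean
    return {j: [s / counts[j] for s in sums[j]] for j in sums}
```

For instance, with an identity embedding, two class-0 samples embedded as [1, 0] and [3, 0] yield the prototype [2.0, 0.0].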

Prototype-based model inference

In the inference stage of the learning task, we can simply predict the label $\hat{y}$ of an instance $x$ by measuring the L2 distance between the instance's representation vector $f(\phi;x)$ and each prototype $C^{(j)}$ as follows:

\hat{y}=\operatorname*{arg\,min}_{j}\|f(\phi;x)-C^{(j)}\|_2. (4)
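A minimal sketch of this nearest-prototype rule, assuming prototypes and embeddings are plain lists of floats (`predict` is a hypothetical name for illustration):

```python
def predict(x, prototypes, embed):
    """Prototype-based inference (Eq. 4): assign x the class whose
    prototype is closest in L2 distance to the embedding f(phi; x).

    prototypes: {class label j: prototype C^(j) as a list of floats}
    embed:      stand-in for the representation layers f(phi; .).
    """
    h = embed(x)

    def l2(c):
        # Euclidean distance between the embedding and a prototype
        return sum((a - b) ** 2 for a, b in zip(h, c)) ** 0.5

    return min(prototypes, key=lambda j: l2(prototypes[j]))
```

With prototypes {"cat": [0, 0], "dog": [1, 1]}, an input embedded near [0.9, 1.1] is labelled "dog".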

Methodology

We propose a solution for heterogeneous FL that uses prototypes as the key component for exchanging information across the server and the clients.

An overview of the proposed framework is shown in Figure 1. The central server receives local prototype sets $C_1,C_2,\dots,C_m$ from the $m$ local clients, and then aggregates the prototypes by averaging them. In the heterogeneous FL setting, these prototype sets overlap but are not identical. Taking the MNIST dataset as an example, the first client recognizes the digits $\{2,3,4\}$, while another client recognizes the digits $\{4,5\}$. These are two different sets of handwritten digits; nonetheless, there is an overlap. The server automatically aggregates prototypes from the overlapping class space across the clients.

Using prototypes in FL, we do not need to exchange gradients or model parameters, which means that the proposed solution can tackle heterogeneous model architectures. Moreover, the prototype-based FL does not require each client to provide the same classes, meaning the heterogeneous class spaces are well supported. Thus, heterogeneity challenges in FL can be addressed.

Optimization Objective

The objective of FedProto is to solve a joint optimization problem on a distributed network. FedProto applies prototype-based communication, which allows a local model to align its prototypes with other local models while minimizing the sum of loss for all clients’ local learning tasks. The objective of federated prototype learning across heterogeneous clients can be formulated as

\operatorname*{arg\,min}_{\{\bar{C}^{(j)}\}_{j=1}^{|\mathbb{C}|}}\sum_{i=1}^{m}\frac{|D_i|}{N}\mathcal{L}_S(\mathcal{F}_i(\omega_i;x),y)+\lambda\cdot\sum_{j=1}^{|\mathbb{C}|}\sum_{i=1}^{m}\frac{|D_{i,j}|}{N_j}\mathcal{L}_R(\bar{C}_i^{(j)},C_i^{(j)}), (5)

where $\mathcal{L}_S$ is the supervised loss (as defined in Eq. (2)) and $\mathcal{L}_R$ is a regularization term that measures the distance (we use L2 distance) between a local prototype $C_i^{(j)}$ and the corresponding global prototype $\bar{C}^{(j)}$. $N$ is the total number of instances over all clients, and $N_j$ is the number of instances belonging to class $j$ over all clients.

The optimization problem can be addressed by alternating minimization over two steps: (1) minimization w.r.t. each $\omega_i$ with $\bar{C}_i^{(j)}$ fixed; and (2) minimization w.r.t. $\bar{C}_i^{(j)}$ with all $\omega_i$ fixed. In a distributed setting, step (1) reduces to conventional supervised learning on each client using its local data, while step (2) aggregates local prototypes from the clients on the server. Further details concerning these two steps are given in Algorithm 1.

Algorithm 1 FedProto

Input: $D_i, \omega_i$, $i=1,\dots,m$
Server executes:

1:  Initialize the global prototype set $\{\bar{C}^{(j)}\}$ for all classes.
2:  for each round $T=1,2,\dots$ do
3:     for each client $i$ in parallel do
4:        $C_i \leftarrow$ LocalUpdate$(i,\bar{C}_i)$
5:     end for
6:     Update the global prototypes by Eq. (6).
7:     Update each local prototype set $C_i$ with the prototypes in $\{\bar{C}^{(j)}\}$.
8:  end for

LocalUpdate$(i,\bar{C}_i)$:

1:  for each local epoch do
2:     for each batch $(x_i,y_i)\in D_i$ do
3:        Compute local prototypes by Eq. (3).
4:        Compute the loss by Eq. (7) using the local prototypes.
5:        Update the local model according to the loss.
6:     end for
7:  end for
8:  return $C_i$

Global Prototype Aggregation

Given the data and model heterogeneity across the participating clients, the optimal model parameters differ from client to client, which means that gradient-based communication cannot provide sufficiently useful information to each client. However, a shared label space allows the participating clients to share the same embedding space, and information can be efficiently exchanged across heterogeneous clients by aggregating prototypes according to the classes they belong to.

Given a class $j$, the server receives prototypes from the set of clients that have class $j$. A global prototype $\bar{C}^{(j)}$ for class $j$ is generated by the prototype aggregation operation,

\bar{C}^{(j)}=\frac{1}{|\mathcal{N}_j|}\sum_{i\in\mathcal{N}_j}\frac{|D_{i,j}|}{N_j}C_i^{(j)}, (6)

where $C_i^{(j)}$ denotes the prototype of class $j$ from client $i$, and $\mathcal{N}_j$ denotes the set of clients that have class $j$.
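A sketch of this aggregation step, following Eq. (6) term by term; the dictionary layout and the function name `aggregate_prototypes` are assumptions made for illustration:

```python
def aggregate_prototypes(local_protos, local_counts):
    """Global prototype aggregation following Eq. (6).

    local_protos: {client i: {class j: prototype C_i^(j) as a list of floats}}
    local_counts: {client i: {class j: |D_{i,j}|, the local count of class j}}
    """
    classes = {j for protos in local_protos.values() for j in protos}
    global_protos = {}
    for j in sorted(classes, key=str):
        owners = [i for i in local_protos if j in local_protos[i]]  # the set N_j of clients
        n_j = sum(local_counts[i][j] for i in owners)               # N_j, total count of class j
        acc = [0.0] * len(local_protos[owners[0]][j])
        for i in owners:
            w = local_counts[i][j] / n_j                            # weight |D_{i,j}| / N_j
            acc = [a + w * v for a, v in zip(acc, local_protos[i][j])]
        global_protos[j] = [a / len(owners) for a in acc]           # the 1/|N_j| factor of Eq. (6)
    return global_protos
```

For two clients holding class "a" with counts 1 and 3 and prototypes [0, 0] and [4, 4], the weighted sum is [3, 3], which the leading 1/|N_j| factor halves to [1.5, 1.5].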

Local Model Update

Each client updates its local model to generate prototypes that are consistent across clients. To this end, a regularization term is added to the local loss function, pulling the local prototypes $C_i^{(j)}$ toward the global prototypes $\bar{C}^{(j)}$ while minimizing the classification loss. In particular, the loss function is defined as follows:

\mathcal{L}(D_i,\omega_i)=\mathcal{L}_S(\mathcal{F}_i(\omega_i;x),y)+\lambda\cdot\mathcal{L}_R(\bar{C}_i,C_i), (7)

where $\lambda$ is an importance weight, and $\mathcal{L}_R$ is the regularization term defined as:

\mathcal{L}_R=\sum_{j}d(C_i^{(j)},\bar{C}_i^{(j)}), (8)

where $d$ is a distance metric between the locally generated prototypes $C_i^{(j)}$ and the globally aggregated prototypes $\bar{C}^{(j)}$. The distance measure can take a variety of forms, such as L1 distance, L2 distance, or earth mover's distance.
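Putting Eq. (7) and Eq. (8) together, the local objective can be sketched as follows. Two assumptions are made for illustration: the distance $d$ is instantiated as squared L2 (one valid choice among those listed above), and the supervised loss $\mathcal{L}_S$ is taken as a precomputed scalar:

```python
def local_loss(sup_loss, local_protos, global_protos, lam):
    """Local objective of Eq. (7): supervised loss plus lambda times the
    prototype regularizer of Eq. (8).

    sup_loss:      precomputed supervised loss L_S on a batch (a scalar here).
    local_protos:  {class j: local prototype C_i^(j)}
    global_protos: {class j: global prototype bar-C^(j)}
    lam:           the importance weight lambda.
    The distance d is squared L2 here -- an illustrative choice.
    """
    reg = 0.0
    for j, c in local_protos.items():
        if j in global_protos:  # regularize only classes with a global prototype
            reg += sum((a - b) ** 2 for a, b in zip(c, global_protos[j]))
    return sup_loss + lam * reg
```

For example, with a supervised loss of 1.0, a single local prototype [1, 0] against a global prototype [0, 0], and lambda = 0.5, the total loss is 1.0 + 0.5 * 1.0 = 1.5.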

Convergence Analysis

We provide insights into the convergence analysis of FedProto. We denote the local objective function defined in Eq. (7) as $\mathcal{L}$, with a subscript indicating the iteration number, and make the following assumptions, similar to existing general frameworks (Wang et al. 2020b; Li et al. 2020).

Assumption 1.

(Lipschitz Smoothness). Each local objective function is $L_1$-Lipschitz smooth, i.e., its gradient is $L_1$-Lipschitz continuous:

\|\nabla\mathcal{L}_{t_1}-\nabla\mathcal{L}_{t_2}\|_2\leq L_1\|\omega_{i,t_1}-\omega_{i,t_2}\|_2,\quad\forall t_1,t_2>0,\ i\in\{1,2,\dots,m\}. (9)

This also implies the following quadratic bound,

\mathcal{L}_{t_1}-\mathcal{L}_{t_2}\leq\langle\nabla\mathcal{L}_{t_2},\,\omega_{i,t_1}-\omega_{i,t_2}\rangle+\frac{L_1}{2}\|\omega_{i,t_1}-\omega_{i,t_2}\|_2^2,\quad\forall t_1,t_2>0,\ i\in\{1,2,\dots,m\}. (10)
Assumption 2.

(Unbiased Gradient and Bounded Variance). The stochastic gradient $g_{i,t}=\nabla\mathcal{L}(\omega_t,\xi_t)$ is an unbiased estimator of the local gradient for each client, i.e., its expectation satisfies

\mathbb{E}_{\xi_i\sim D_i}[g_{i,t}]=\nabla\mathcal{L}(\omega_{i,t})=\nabla\mathcal{L}_t,\quad\forall i\in\{1,2,\dots,m\}, (11)

and its variance is bounded by $\sigma^2$:

\mathbb{E}[\|g_{i,t}-\nabla\mathcal{L}(\omega_{i,t})\|_2^2]\leq\sigma^2,\quad\forall i\in\{1,2,\dots,m\},\ \sigma^2\geq 0. (12)
Assumption 3.

(Bounded Expectation of the Euclidean Norm of Stochastic Gradients). The expectation of the stochastic gradient is bounded by $G$:

\mathbb{E}[\|g_{i,t}\|_2]\leq G,\quad\forall i\in\{1,2,\dots,m\}. (13)
Assumption 4.

(Lipschitz Continuity). Each local embedding function is $L_2$-Lipschitz continuous, that is,

\|f_i(\phi_{i,t_1})-f_i(\phi_{i,t_2})\|\leq L_2\|\phi_{i,t_1}-\phi_{i,t_2}\|_2,\quad\forall t_1,t_2>0,\ i\in\{1,2,\dots,m\}. (14)

Based on the above assumptions, we present theoretical results for the non-convex problem. The expected decrease per round is given in Theorem 1. We denote $e\in\{1/2,1,2,\dots,E\}$ as the local iteration and $t$ as the global communication round. Moreover, $tE$ represents the time step before prototype aggregation, and $tE+1/2$ represents the time step between prototype aggregation and the first iteration of the current round.

Theorem 1.

(One-round deviation). Let Assumptions 1 to 4 hold. For an arbitrary client, after every communication round, we have

\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]\leq\mathcal{L}_{tE+1/2}-\left(\eta-\frac{L_1\eta^2}{2}\right)\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2+\frac{L_1E\eta^2}{2}\sigma^2+\lambda L_2\eta EG. (15)

Theorem 1 indicates the deviation bound of the local objective function for an arbitrary client after each communication round. Convergence can be guaranteed when there is a certain expected one-round decrease, which can be achieved by choosing appropriate η𝜂\eta and λ𝜆\lambda.

Corollary 1.

(Non-convex FedProto convergence). The loss function $\mathcal{L}$ of an arbitrary client monotonically decreases in every communication round when

\eta_{e'}<\frac{2\left(\sum_{e=1/2}^{e'}\|\nabla\mathcal{L}_{tE+e}\|_2^2-\lambda L_2EG\right)}{L_1\left(\sum_{e=1/2}^{e'}\|\nabla\mathcal{L}_{tE+e}\|_2^2+E\sigma^2\right)}, (16)

where $e'=1/2,1,\dots,E-1$, and

\lambda_t<\frac{\|\nabla\mathcal{L}_{tE+1/2}\|_2^2}{L_2EG}. (17)

Thus, the loss function converges.

Corollary 1 ensures that the expected deviation of $\mathcal{L}$ is negative, so the loss function converges. It can guide the choice of appropriate values for the learning rate $\eta$ and the importance weight $\lambda$ to guarantee convergence.

Theorem 2.

(Non-convex convergence rate of FedProto). Let Assumptions 1 to 4 hold, and let $\Delta=\mathcal{L}_0-\mathcal{L}^*$, where $\mathcal{L}^*$ refers to the local optimum. For an arbitrary client, given any $\epsilon>0$, after

T=\frac{2\Delta}{E\epsilon(2\eta-L_1\eta^2)-E\eta(L_1\eta\sigma^2+2\lambda L_2G)} (18)

communication rounds of FedProto, we have

\frac{1}{TE}\sum_{t=0}^{T-1}\sum_{e=1/2}^{E-1}\mathbb{E}[\|\nabla\mathcal{L}_{tE+e}\|_2^2]<\epsilon, (19)

if

\eta<\frac{2(\epsilon-\lambda L_2G)}{L_1(\epsilon+\sigma^2)}\quad\text{and}\quad\lambda<\frac{\epsilon}{L_2G}.

Theorem 2 provides the convergence rate: the expected L2-norm of the gradients can be confined to any bound $\epsilon$ by appropriately selecting the number of communication rounds $T$ and the hyperparameters $\eta$ and $\lambda$. The smaller $\epsilon$ is, the larger $T$ is, meaning that a tighter bound requires more communication rounds. A detailed proof and analysis are given in Appendix B.
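As an illustration of how these bounds guide hyperparameter choice, one can plug constants into the conditions of Theorem 2. The values of $L_1$, $L_2$, $G$, $\sigma^2$, and $\epsilon$ below are hypothetical, not measured from any experiment:

```python
# Hypothetical constants (assumptions for illustration only).
L1, L2, G, sigma2, eps = 10.0, 1.0, 5.0, 0.1, 0.5

# Theorem 2 requires lambda < eps / (L2 * G) ...
lam_max = eps / (L2 * G)
lam = 0.5 * lam_max  # pick lambda safely inside the bound

# ... and eta < 2 * (eps - lam * L2 * G) / (L1 * (eps + sigma^2)).
eta_max = 2 * (eps - lam * L2 * G) / (L1 * (eps + sigma2))
```

With these numbers, $\lambda$ must stay below 0.1, and for $\lambda=0.05$ the learning rate must stay below $2(0.5-0.25)/(10\times 0.6)\approx 0.083$.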

Discussion

In this section, we discuss the superiority of FedProto from three perspectives: model inference, communication efficiency, and privacy preserving.

Model Inference

Unlike many FL methods, the global model in FedProto is not a classifier but a set of class prototypes. When a new client joins the network, one can initialize its local model with the representation layers of a pre-trained model (e.g., a ResNet18 pre-trained on ImageNet) and randomly initialized decision layers. The client then downloads the global prototypes of the classes covered by its local dataset and fine-tunes the local model by minimizing the local objective. This supports new clients with novel model architectures while requiring less time to fine-tune the model on heterogeneous datasets.

Communication Efficiency

Our proposed method only transmits prototypes between the server and clients. The size of a prototype set is usually much smaller than that of the model parameters. Taking MNIST as an example, the prototype size is 50 per class, while the number of model parameters is 21,500. More details can be found in the experimental section.
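A quick back-of-the-envelope check of this claim, using the figures stated above (the 50-dimensional prototype and the 21,500-parameter model are as reported; taking all 10 MNIST digit classes as present is an assumption that gives an upper bound on the payload):

```python
proto_dim = 50         # prototype size per class, as stated above
num_classes = 10       # MNIST has 10 digit classes
model_params = 21_500  # model parameter count, as stated above

# Upper bound on a client's per-round upload: one prototype per class.
proto_payload = proto_dim * num_classes
ratio = model_params / proto_payload
```

So prototype exchange transmits at most 500 values per round versus 21,500 for gradient-based aggregation, a 43x reduction (larger still for clients holding fewer classes).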

| Dataset | Method | Stdev | Test Avg Acc (n=3) | Test Avg Acc (n=4) | Test Avg Acc (n=5) | # Comm Rounds | # Comm Params (×10³) |
|---|---|---|---|---|---|---|---|
| MNIST | Local | 2 | 94.05±2.93 | 93.35±3.26 | 92.92±3.17 | 0 | 0 |
| MNIST | FeSEM (Xie et al. 2020) | 2 | 95.26±3.48 | 97.06±2.72 | 96.31±2.41 | 150 | 430 |
| MNIST | FedProx (Li et al. 2020) | 2 | 96.26±2.89 | 96.40±3.33 | 95.65±3.38 | 110 | 430 |
| MNIST | FedPer (Arivazhagan et al. 2019) | 2 | 95.57±2.96 | 96.44±2.62 | 95.55±3.13 | 100 | 106 |
| MNIST | FedAvg (McMahan et al. 2017) | 2 | 95.04±6.48 | 94.32±4.89 | 93.22±4.39 | 150 | 430 |
| MNIST | FedRep (Collins et al. 2021) | 2 | 94.96±2.78 | 95.18±3.80 | 94.94±2.81 | 100 | 110 |
| MNIST | FedProto | 2 | 97.13±0.30 | 96.80±0.41 | 96.70±0.29 | 100 | 4 |
| MNIST | FedProto-mh | 2 | 97.07±0.50 | 96.65±0.31 | 96.22±0.36 | 100 | 4 |
| FEMNIST | Local | 1 | 92.50±10.42 | 91.16±5.64 | 87.91±8.44 | 0 | 0 |
| FEMNIST | FeSEM (Xie et al. 2020) | 1 | 93.39±6.75 | 91.06±6.43 | 89.61±7.89 | 200 | 16,000 |
| FEMNIST | FedProx (Li et al. 2020) | 1 | 94.53±5.33 | 90.71±6.24 | 91.33±7.32 | 300 | 16,000 |
| FEMNIST | FedPer (Arivazhagan et al. 2019) | 1 | 93.47±5.44 | 90.22±7.63 | 87.73±9.64 | 250 | 102 |
| FEMNIST | FedAvg (McMahan et al. 2017) | 1 | 94.50±5.29 | 91.39±5.23 | 90.95±7.22 | 300 | 16,000 |
| FEMNIST | FedRep (Collins et al. 2021) | 1 | 93.36±5.34 | 91.41±5.89 | 89.98±6.88 | 200 | 102 |
| FEMNIST | FedProto | 1 | 96.82±1.75 | 94.93±1.61 | 93.67±2.23 | 120 | 4 |
| FEMNIST | FedProto-mh | 1 | 97.10±1.63 | 94.83±1.60 | 93.76±2.30 | 120 | 4 |
| CIFAR10 | Local | 1 | 79.72±9.45 | 67.62±7.15 | 58.64±6.57 | 0 | 0 |
| CIFAR10 | FeSEM (Xie et al. 2020) | 1 | 80.19±3.31 | 76.40±3.23 | 74.17±3.51 | 120 | 235,000 |
| CIFAR10 | FedProx (Li et al. 2020) | 1 | 83.25±2.44 | 79.20±1.31 | 76.19±2.23 | 150 | 235,000 |
| CIFAR10 | FedPer (Arivazhagan et al. 2019) | 1 | 84.38±4.58 | 78.73±4.59 | 76.21±4.27 | 130 | 225,000 |
| CIFAR10 | FedAvg (McMahan et al. 2017) | 1 | 81.72±2.77 | 76.77±2.37 | 75.74±2.61 | 150 | 235,000 |
| CIFAR10 | FedRep (Collins et al. 2021) | 1 | 81.44±10.48 | 76.93±7.46 | 73.36±7.04 | 110 | 225,000 |
| CIFAR10 | FedProto | 1 | 84.49±1.97 | 79.12±2.03 | 77.08±1.98 | 110 | 41 |
| CIFAR10 | FedProto-mh | 1 | 83.63±1.60 | 79.49±1.78 | 76.94±1.33 | 110 | 41 |

Table 1: Comparison of FL methods on three benchmark datasets with non-IID split over clients. The best results are in bold. FedProto achieves higher accuracy than the baselines while communicating far fewer parameters.

Privacy Preserving

The proposed FedProto exchanges prototypes rather than model parameters between the server and the clients. This property benefits FL in terms of privacy preservation. First, prototypes naturally protect data privacy, because each prototype is a 1-D vector generated by averaging the low-dimensional representations of samples from the same class, which is an irreversible process. Second, attackers cannot reconstruct the raw data from prototypes without access to the local models. Moreover, FedProto can be integrated with various privacy-preserving techniques to further enhance the reliability of the system.
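A minimal sketch of why the averaging step cannot be inverted: many different sets of embeddings share the same mean, so a prototype alone does not identify the samples that produced it. The vectors and dimensions below are purely illustrative.

```python
# A class prototype is the element-wise mean of the embeddings of that
# class's samples; distinct embedding sets can collapse to the same mean,
# so the mapping from samples to prototype is many-to-one.

def prototype(embeddings):
    """Element-wise mean of a list of equal-length embedding vectors."""
    n = len(embeddings)
    return tuple(sum(dim) / n for dim in zip(*embeddings))

set_a = [(1.0, 4.0), (3.0, 0.0)]
set_b = [(0.0, 2.0), (4.0, 2.0)]  # different "samples"...
assert prototype(set_a) == prototype(set_b) == (2.0, 2.0)  # ...same prototype
```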

Experiments

Training Setups

Datasets and local models

We implement the typical federated setting in which each client owns its local data and transmits/receives information to/from a central server. We use three popular benchmark datasets: MNIST (LeCun 1998), FEMNIST (Caldas et al. 2018), and CIFAR10 (Krizhevsky, Hinton et al. 2009). We use a multi-layer CNN consisting of 2 convolutional layers followed by 2 fully connected layers for both MNIST and FEMNIST, and ResNet18 (He et al. 2016) for CIFAR10.

Local tasks

Each client learns a supervised learning task. To describe the local tasks, we borrow the concept of n-way k-shot from few-shot learning, where n controls the number of classes and k controls the number of training instances per class. To mimic the heterogeneous scenario, we randomly vary n and k across clients: we define an average value for each and add random noise to every client's n and k. The variance of n controls the heterogeneity of the class space, while the variance of k controls the imbalance in data size.
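The text specifies only "average value plus random noise"; the sketch below is one plausible instantiation of the task assignment, using Gaussian noise rounded to integers and clamped to valid ranges. The function name and clamping bounds are our assumptions.

```python
import random

def sample_client_tasks(num_clients, avg_n, avg_k, std_n, std_k,
                        num_classes=10, seed=0):
    """Draw an (n, k) pair per client: n classes, k shots per class.

    avg_n/avg_k are the population means; std_n/std_k control the
    heterogeneity of the class space and the data-size imbalance.
    This is an illustrative scheme, not the authors' exact sampler.
    """
    rng = random.Random(seed)
    tasks = []
    for _ in range(num_clients):
        n = round(rng.gauss(avg_n, std_n))
        k = round(rng.gauss(avg_k, std_k))
        n = max(2, min(num_classes, n))  # at least 2 classes, at most all
        k = max(1, k)                    # at least one shot per class
        tasks.append((n, k))
    return tasks

# e.g. the MNIST setting: average n = 3, average k = 100, stdev of n = 2
tasks = sample_client_tasks(num_clients=20, avg_n=3, avg_k=100,
                            std_n=2, std_k=10)
```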

Baselines of FL

We study the performance of FedProto under both the statistical and model heterogeneous settings (FedProto-mh) and make comparisons with baselines, including Local where an individual model is trained for each client without any communication with others, FedAvg (McMahan et al. 2017), FedProx (Li et al. 2020), FeSEM (Xie et al. 2020), FedPer (Arivazhagan et al. 2019), and FedRep (Collins et al. 2021).

Figure 2: t-SNE visualization of the samples and/or prototypes produced by FedProto and other FL methods. We consider 20 clients for MNIST. The average number of classes per client is n=3. (a) FedProto: samples within the same class have multiple centers, each representing the local prototype of one client; the global prototype is the central point of the class samples. (b) FedAvg: samples within the same class cluster in the same area. (c) FeSEM: samples within the same class gather in several clusters according to the algorithm parameters. (d) FedPer: each cluster corresponds to one class in a specific client.

Implementation Details

We implement FedProto and the baseline methods in PyTorch. We use 20 clients for all datasets, and all clients are sampled in each communication round. The average size of each class in each client is set to 100. For the MNIST and FEMNIST datasets, our initial hyperparameters were taken directly from the defaults in (McMahan et al. 2017). For CIFAR10, ResNet18 pre-trained on ImageNet (Krizhevsky, Sutskever, and Hinton 2017) is used as the initial model; its initial average test accuracy on CIFAR10 is 27.55%. A detailed setup, including the choice of hyperparameters, is given in Appendix A.

Performance in Non-IID Federated Setting

We compare FedProto with other baseline methods that are either classical FL methods or FL methods with an emphasis on statistical heterogeneity. All methods are adapted to fit this heterogeneous setting.

Statistical heterogeneity simulations

In our setting, we assume that all clients perform learning tasks with heterogeneous statistical distributions. To simulate different levels of heterogeneity, we set the standard deviation of n to 1 or 2, creating heterogeneity in both class spaces and data sizes, which is common in real-world scenarios.

Model heterogeneity simulations

For the model heterogeneous setting, we consider minor differences in model architectures across clients. In MNIST and FEMNIST, the number of output channels in the convolutional layers is set to 18, 20 or 22, while in CIFAR10, the stride of the convolutional layers differs across clients. This kind of model heterogeneity poses a challenge for model parameter averaging, because the parameters of different clients do not always have the same size.
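To see why naive weight averaging breaks under even minor architectural differences, consider the parameter count of a single convolutional layer at the three channel widths used here. The kernel size and input-channel count below are illustrative assumptions, not the paper's exact architecture.

```python
def conv_param_count(in_ch, out_ch, kernel=5):
    """Weights plus biases of a 2-D convolutional layer:
    out_ch * in_ch * kernel^2 weights + out_ch biases."""
    return out_ch * in_ch * kernel * kernel + out_ch

# Clients differ only in the number of output channels (18, 20, or 22),
# yet their weight tensors already have incompatible shapes, so no
# element-wise average of the parameters exists.
counts = {c: conv_param_count(in_ch=1, out_ch=c) for c in (18, 20, 22)}
print(counts)  # three different sizes
```

Prototype exchange sidesteps this entirely, since all clients map into a common embedding space of fixed dimension.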

The average test accuracy over all clients is reported in Table 1. FedProto achieves the highest accuracy and the lowest variance in most cases, ensuring uniformity among heterogeneous clients.

Communication efficiency

Communication cost has long been a challenge in FL, given the limitations of existing communication channels. We therefore also report the number of communication rounds required for convergence and the number of parameters communicated per round in Table 1. The number of parameters communicated per round in FedProto is much lower than in FedAvg, and FedProto requires the fewest communication rounds to converge. This suggests that when the heterogeneity level across clients is high, sharing more parameters does not always lead to better results; it is more important to identify which part to share so as to maximize the benefit to the system. More performance results are shown in Appendix A.

Visualization of prototypes achieved by FedProto

We visualize the samples in the MNIST test set by t-SNE (Van der Maaten and Hinton 2008). In Figure 2(a), small points in different colors represent samples in different classes, and large points represent the corresponding global prototypes. In Figures 2(b), 2(c) and 2(d), points in different colors are the representations of samples belonging to different classes. Better generalization means that more samples within the same class cluster in the same area, which is achievable in a centralized setting, while better personalization means that it is easier to determine to which client a sample belongs. Samples within the same class but from different clients are close yet separable in FedProto, indicating that FedProto better balances generalization and personalization, while the other methods lack either the generalization or the personalization ability.

Scalability of FedProto on varying number of samples

Figure 3: Average test accuracy of FedProto and FedAvg on MNIST with varying number of samples in each class.

Figure 3 shows that FedProto scales to scenarios with fewer samples available on clients. Test accuracy decreases as fewer samples are available for training, but FedProto degrades more slowly than FedAvg, reflecting its adaptability to varying data sizes.

FedProto under varying λ

Figure 4 shows the performance under different values of λ in Eq. (5). We tried a set of values selected from [0, 4] and report the average test accuracy and prototype distance loss with n=3, k=100 on the FEMNIST dataset. The best value of λ is 1 in this scenario. As λ increases, the prototype distance loss (regularization term) decreases, while the average test accuracy rises sharply from λ=0 to λ=1 before dropping by about 6%, demonstrating the efficacy of prototype aggregation.

Figure 4: Average test accuracy on FEMNIST under varying importance weight λ.

Conclusion

In this paper, we propose a novel prototype aggregation-based FL method to tackle challenging FL scenarios with heterogeneous input/output spaces, data distributions, and model architectures. The proposed method collaboratively trains intelligent models by exchanging prototypes rather than gradients, which offers new insights for designing prototype-based FL. The effectiveness of the proposed method has been comprehensively analyzed from both theoretical and experimental perspectives.

References

  • Arivazhagan et al. (2019) Arivazhagan, M. G.; Aggarwal, V.; Singh, A. K.; and Choudhary, S. 2019. Federated learning with personalization layers. arXiv preprint arXiv:1912.00818.
  • Babenko and Lempitsky (2015) Babenko, A.; and Lempitsky, V. 2015. Aggregating local deep features for image retrieval. In Proceedings of the IEEE international conference on computer vision, 1269–1277.
  • Cai et al. (2020) Cai, T.; Li, J.; Mian, A. S.; Sellis, T.; Yu, J. X.; et al. 2020. Target-aware holistic influence maximization in spatial social networks. IEEE Transactions on Knowledge and Data Engineering.
  • Caldas et al. (2018) Caldas, S.; Duddu, S. M. K.; Wu, P.; Li, T.; Konečnỳ, J.; McMahan, H. B.; Smith, V.; and Talwalkar, A. 2018. Leaf: A benchmark for federated settings. arXiv: 1812.01097.
  • Chen et al. (2020) Chen, C.; Zhang, J.; Tung, A. K.; Kankanhalli, M.; and Chen, G. 2020. Robust federated recommendation system. arXiv preprint arXiv:2006.08259.
  • Collins et al. (2021) Collins, L.; Hassani, H.; Mokhtari, A.; and Shakkottai, S. 2021. Exploiting Shared Representations for Personalized Federated Learning. International Conference on Machine Learning.
  • Deng, Kamani, and Mahdavi (2020) Deng, Y.; Kamani, M. M.; and Mahdavi, M. 2020. Adaptive Personalized Federated Learning. arXiv:2003.13461.
  • Dvornik, Schmid, and Mairal (2020) Dvornik, N.; Schmid, C.; and Mairal, J. 2020. Selecting relevant features from a universal representation for few-shot classification.
  • Fallah, Mokhtari, and Ozdaglar (2020) Fallah, A.; Mokhtari, A.; and Ozdaglar, A. 2020. Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach. In Advances in Neural Information Processing Systems.
  • Ghosh et al. (2020) Ghosh, A.; Chung, J.; Yin, D.; and Ramchandran, K. 2020. An Efficient Framework for Clustered Federated Learning. In Advances in Neural Information Processing Systems.
  • He, Annavaram, and Avestimehr (2020) He, C.; Annavaram, M.; and Avestimehr, S. 2020. FedNAS: Federated Deep Learning via Neural Architecture Search. In Proceedings of the IEEE conference on computer vision and pattern recognition.
  • He et al. (2020) He, C.; Li, S.; So, J.; Zhang, M.; Wang, H.; Wang, X.; Vepakomma, P.; Singh, A.; Qiu, H.; Shen, L.; et al. 2020. Fedml: A research library and benchmark for federated machine learning. arXiv:2007.13518.
  • He et al. (2016) He, K.; Zhang, X.; Ren, S.; and Sun, J. 2016. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, 770–778.
  • Hoang et al. (2019) Hoang, M.; Hoang, N.; Low, B. K. H.; and Kingsford, C. 2019. Collective model fusion for multiple black-box experts. In International Conference on Machine Learning, 2742–2750. PMLR.
  • Hoang et al. (2020) Hoang, N.; Lam, T.; Low, B. K. H.; and Jaillet, P. 2020. Learning Task-Agnostic Embedding of Multiple Black-Box Experts for Multi-Task Model Fusion. In International Conference on Machine Learning, 4282–4292. PMLR.
  • Jeong et al. (2018) Jeong, E.; Oh, S.; Kim, H.; Park, J.; Bennis, M.; and Kim, S.-L. 2018. Communication-efficient on-device machine learning: Federated distillation and augmentation under non-IID private data. In Advances in Neural Information Processing Systems.
  • Jiang, Ji, and Long (2020) Jiang, J.; Ji, S.; and Long, G. 2020. Decentralized knowledge acquisition for mobile internet applications. World Wide Web, 1–17.
  • Kairouz et al. (2019) Kairouz, P.; McMahan, H. B.; Avent, B.; Bellet, A.; Bennis, M.; Bhagoji, A. N.; et al. 2019. Advances and open problems in federated learning. arXiv:1912.04977.
  • Krizhevsky, Hinton et al. (2009) Krizhevsky, A.; Hinton, G.; et al. 2009. Learning multiple layers of features from tiny images.
  • Krizhevsky, Sutskever, and Hinton (2017) Krizhevsky, A.; Sutskever, I.; and Hinton, G. E. 2017. Imagenet classification with deep convolutional neural networks. Communications of the ACM, 60(6): 84–90.
  • LeCun (1998) LeCun, Y. 1998. The MNIST database of handwritten digits. http://yann.lecun.com/exdb/mnist/.
  • Li and Wang (2020) Li, D.; and Wang, J. 2020. Fedmd: Heterogenous federated learning via model distillation. In Advances in Neural Information Processing Systems.
  • Li, He, and Song (2021) Li, Q.; He, B.; and Song, D. 2021. Model-Contrastive Federated Learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 10713–10722.
  • Li et al. (2020) Li, T.; Sahu, A. K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; and Smith, V. 2020. Federated optimization in heterogeneous networks. MLSys.
  • Li et al. (2021) Li, Z.; Wang, X.; Li, J.; and Zhang, Q. 2021. Deep attributed network representation learning of complex coupling and interaction. Knowledge-Based Systems, 212: 106618.
  • Liang et al. (2020) Liang, P. P.; Liu, T.; Ziyin, L.; Salakhutdinov, R.; and Morency, L.-P. 2020. Think Locally, Act Globally: Federated Learning with Local and Global Representations. Advances in Neural Information Processing Systems.
  • Lin et al. (2020) Lin, T.; Kong, L.; Stich, S. U.; and Jaggi, M. 2020. Ensemble Distillation for Robust Model Fusion in Federated Learning. In Advances in Neural Information Processing Systems.
  • Liu et al. (2020) Liu, L.; Hamilton, W. L.; Long, G.; Jiang, J.; and Larochelle, H. 2020. A Universal Representation Transformer Layer for Few-Shot Image Classification. In International Conference on Learning Representations.
  • Liu et al. (2021a) Liu, Y.; Pan, S.; Jin, M.; Zhou, C.; Xia, F.; and Yu, P. S. 2021a. Graph self-supervised learning: A survey. arXiv preprint arXiv:2103.00111.
  • Liu et al. (2021b) Liu, Y.; Pan, S.; Wang, Y. G.; Xiong, F.; Wang, L.; and Lee, V. 2021b. Anomaly Detection in Dynamic Graphs via Transformer. IEEE Transactions on Knowledge and Data Engineering.
  • Long et al. (2021) Long, G.; Shen, T.; Tan, Y.; Gerrard, L.; Clarke, A.; and Jiang, J. 2021. Federated learning for privacy-preserving open innovation future on digital health. In Humanity Driven AI. Springer.
  • Long et al. (2020) Long, G.; Tan, Y.; Jiang, J.; and Zhang, C. 2020. Federated Learning for Open Banking. In Federated Learning, 240–254. Springer.
  • Luo et al. (2019) Luo, J.; Wu, X.; Luo, Y.; Huang, A.; Huang, Y.; Liu, Y.; and Yang, Q. 2019. Real-world image datasets for federated learning. arXiv:1910.11089.
  • Mansour et al. (2020) Mansour, Y.; Mohri, M.; Ro, J.; and Suresh, A. T. 2020. Three approaches for personalization with applications to federated learning. arXiv:2002.10619.
  • McMahan et al. (2017) McMahan, H. B.; Moore, E.; Ramage, D.; et al. 2017. Communication-efficient learning of deep networks from decentralized data. AISTATS.
  • Michieli and Ozay (2021) Michieli, U.; and Ozay, M. 2021. Prototype Guided Federated Learning of Visual Feature Representations. arXiv preprint arXiv:2105.08982.
  • Rieke et al. (2020) Rieke, N.; Hancox, J.; Li, W.; Milletari, F.; Roth, H. R.; Albarqouni, S.; Bakas, S.; Galtier, M. N.; Landman, B. A.; Maier-Hein, K.; et al. 2020. The future of digital health with federated learning. NPJ digital medicine, 3(1): 1–7.
  • Sattler, Müller, and Samek (2020) Sattler, F.; Müller, K.-R.; and Samek, W. 2020. Clustered federated learning: Model-agnostic distributed multitask optimization under privacy constraints. IEEE Transactions on Neural Networks and Learning Systems.
  • Simonyan and Zisserman (2014) Simonyan, K.; and Zisserman, A. 2014. Two-Stream Convolutional Networks for Action Recognition in Videos. In Advances in Neural Information Processing Systems, 568–576.
  • Singh et al. (2020) Singh, I.; Zhou, H.; Yang, K.; Ding, M.; Lin, B.; and Xie, P. 2020. Differentially-private federated neural architecture search. In FL-International Conference on Machine Learning Workshop.
  • Snell, Swersky, and Zemel (2017) Snell, J.; Swersky, K.; and Zemel, R. 2017. Prototypical Networks for Few-shot Learning. Advances in Neural Information Processing Systems, 30: 4077–4087.
  • Tan et al. (2021) Tan, A. Z.; Yu, H.; Cui, L.; and Yang, Q. 2021. Towards personalized federated learning. arXiv preprint arXiv:2103.00710.
  • Van der Maaten and Hinton (2008) Van der Maaten, L.; and Hinton, G. 2008. Visualizing data using t-SNE. Journal of machine learning research, 9(11).
  • Wang et al. (2020a) Wang, H.; Yurochkin, M.; Sun, Y.; Papailiopoulos, D.; and Khazaeni, Y. 2020a. Federated Learning with Matched Averaging. In International Conference on Learning Representations.
  • Wang et al. (2020b) Wang, J.; Liu, Q.; Liang, H.; Joshi, G.; and Poor, H. V. 2020b. Tackling the Objective Inconsistency Problem in Heterogeneous Federated Optimization. Advances in neural information processing systems.
  • Wieting et al. (2015) Wieting, J.; Bansal, M.; Gimpel, K.; and Livescu, K. 2015. Towards universal paraphrastic sentence embeddings. arXiv:1511.08198.
  • Xie et al. (2020) Xie, M.; Long, G.; Shen, T.; Wang, X.; Tianyi, Z.; and Jiang, J. 2020. Multi-center Federated Learning. arXiv:2005.01026.
  • Xu et al. (2020) Xu, J.; Glicksberg, B. S.; Su, C.; Walker, P.; Bian, J.; and Wang, F. 2020. Federated learning for healthcare informatics. Journal of Healthcare Informatics Research, 1–19.
  • Xue et al. (2021) Xue, G.; Zhong, M.; Li, J.; Chen, J.; Zhai, C.; and Kong, R. 2021. Dynamic network embedding survey. arXiv preprint arXiv:2103.15447.
  • Yang et al. (2021) Yang, Y.; Guan, Z.; Li, J.; Zhao, W.; Cui, J.; and Wang, Q. 2021. Interpretable and efficient heterogeneous graph convolutional network. IEEE Transactions on Knowledge and Data Engineering.
  • Zheng et al. (2021) Zheng, Y.; Jin, M.; Liu, Y.; Chi, L.; Phan, K. T.; and Chen, Y.-P. P. 2021. Generative and Contrastive Self-Supervised Learning for Graph Anomaly Detection. IEEE Transactions on Knowledge and Data Engineering.
  • Zhu, Zhang, and Jin (2020) Zhu, H.; Zhang, H.; and Jin, Y. 2020. From federated learning to federated neural architecture search: a survey. Complex & Intelligent Systems, 1–19.
  • Zhu, Liu, and Han (2019) Zhu, L.; Liu, Z.; and Han, S. 2019. Deep Leakage from Gradients. In Advances in Neural Information Processing Systems, 14774–14784.

We present the related supplements in the following sections.

Experimental Details and Extra Results

Experimental Details

Local clients are trained with the SGD optimizer, using a learning rate of 0.01 and momentum of 0.5. For the crucial hyperparameter λ, we tune the best value from a limited candidate set by grid search; the best λ values for MNIST, FEMNIST and CIFAR10 are 1, 1 and 0.1, respectively. The number of local epochs and the local batch size are set to 1 and 8, respectively, for all datasets. The heterogeneity level of the clients is controlled by the standard deviation of n: the higher it is, the more heterogeneous the clients are.
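For concreteness, the local SGD-with-momentum update (lr 0.01, momentum 0.5) can be sketched as follows. This mirrors the standard PyTorch-style velocity-buffer rule, written out in plain Python for a single scalar parameter as an illustration.

```python
def sgd_momentum_step(param, grad, buf, lr=0.01, momentum=0.5):
    """One SGD-with-momentum update (PyTorch-style velocity buffer):
    buf <- momentum * buf + grad; param <- param - lr * buf."""
    buf = momentum * buf + grad
    param = param - lr * buf
    return param, buf

# Two steps with a constant gradient: the velocity buffer accumulates.
p, b = 1.0, 0.0
p, b = sgd_momentum_step(p, grad=2.0, buf=b)  # buf = 2.0, p ~= 0.98
p, b = sgd_momentum_step(p, grad=2.0, buf=b)  # buf = 3.0, p ~= 0.95
```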

Extra Results

The complete experimental results show the performance of FedProto and FedProto-mh on three benchmark datasets MNIST, FEMNIST, and CIFAR10. Compared with existing FL methods, FedProto yields higher test accuracy while resulting in lower communication costs under different heterogeneous settings. Additionally, it can be used in model heterogeneous scenarios and achieves performance similar to that in homogeneous scenarios.

For MNIST, we evaluate on local test sets and report the results in Table 2. FedProto achieves strong performance at low communication cost: its local average test accuracy is higher than that of the FeSEM, FedProx, FedPer and FedAvg algorithms in all settings.

For FEMNIST, the evaluation results are reported in Table 3. We consider standard deviations of n of 1 and 2. The results show that, for FedProto, the variance of the accuracy across clients is much smaller than for the other FL methods, ensuring uniformity among heterogeneous clients. FedProto better utilizes the local FEMNIST data distribution while communicating only around 0.025% of the parameters.

For CIFAR10, as shown in Table 4, FedProto converges faster in the presence of heterogeneity in most cases. In FedProto and FedProto-mh, the number of parameters communicated per round is much lower than in the baseline methods, which greatly reduces communication costs.

Dataset  Method       Stdev of n  Test Average Acc (n=3 / n=4 / n=5)     # of Comm Rounds  # of Comm Params (×10³)
MNIST    Local        2           94.05±2.93 / 93.35±3.26 / 92.92±3.17   0                 0
                      3           93.44±3.57 / 94.24±2.49 / 93.97±2.97
         FeSEM        2           95.26±3.48 / 97.06±2.72 / 96.31±2.41   150               430
                      3           96.40±3.35 / 95.82±3.94 / 95.98±2.46
         FedProx      2           96.26±2.89 / 96.40±3.33 / 95.65±3.38   110               430
                      3           96.65±3.28 / 95.25±3.73 / 95.34±2.85
         FedPer       2           95.57±2.96 / 96.44±2.62 / 95.55±3.13   100               106
                      3           96.57±2.65 / 95.93±2.76 / 96.07±2.80
         FedAvg       2           91.40±6.48 / 94.32±4.89 / 93.22±4.39   150               430
                      3           94.57±4.91 / 91.99±6.89 / 92.19±3.97
         FedRep       2           94.96±2.78 / 95.18±3.80 / 94.94±2.81   100               110
                      3           95.01±3.92 / 95.55±2.79 / 95.38±2.97
         FedProto     2           97.13±0.30 / 96.80±0.41 / 96.70±0.29   100               4
                      3           96.71±0.43 / 96.87±0.28 / 96.47±0.23
         FedProto-mh  2           97.07±0.50 / 96.65±0.31 / 96.22±0.36   100               4
                      3           96.48±0.43 / 96.84±0.33 / 95.56±0.31
Table 2: Comparison of FL methods on MNIST with non-IID split over clients. Best results are in bold. FedProto achieves higher accuracy than the baseline methods while using far fewer communicated parameters.
Dataset  Method       Stdev of n  Test Average Acc (n=3 / n=4 / n=5)      # of Comm Rounds  # of Comm Params (×10³)
FEMNIST  Local        1           92.50±10.42 / 91.16±5.64 / 87.91±8.44   0                 0
                      2           92.11±6.02 / 90.34±6.42 / 89.70±6.33
         FeSEM        1           93.39±6.75 / 91.06±6.43 / 89.61±7.89    200               16,000
                      2           94.19±4.90 / 93.52±4.47 / 90.77±6.70
         FedProx      1           94.53±5.33 / 90.71±6.24 / 91.33±7.32    300               16,000
                      2           93.49±5.30 / 93.74±5.02 / 89.49±6.74
         FedPer       1           93.47±5.44 / 90.22±7.63 / 87.73±9.64    250               102
                      2           92.27±6.16 / 91.99±6.33 / 87.54±8.14
         FedAvg       1           94.50±5.29 / 91.39±5.23 / 90.95±7.22    300               16,000
                      2           94.13±4.92 / 93.02±5.77 / 89.80±6.94
         FedRep       1           93.36±5.34 / 91.41±5.89 / 89.98±6.88    200               102
                      2           92.28±5.40 / 91.56±7.02 / 88.23±6.97
         FedProto     1           96.82±1.75 / 94.93±1.61 / 93.67±2.23    120               4
                      2           94.93±1.29 / 94.69±1.50 / 93.03±2.50
         FedProto-mh  1           97.10±1.63 / 94.83±1.60 / 93.76±2.30    120               4
                      2           95.33±1.30 / 94.98±1.69 / 92.94±2.34
Table 3: Comparison of FL methods on FEMNIST with non-IID split over clients. Best results are in bold. FedProto achieves higher accuracy than the baseline methods while using far fewer communicated parameters.
Dataset  Method       Stdev of n  Test Average Acc (n=3 / n=4 / n=5)       # of Comm Rounds  # of Comm Params (×10⁴)
CIFAR10  Local        1           79.72±9.45 / 67.62±7.15 / 58.64±6.57     0                 0
                      2           68.15±9.88 / 61.03±11.83 / 58.81±12.90
         FeSEM        1           80.19±3.31 / 76.40±3.23 / 74.17±3.51     120               2.35×10⁴
                      2           76.12±4.15 / 72.11±3.48 / 70.89±3.39
         FedProx      1           83.25±2.44 / 79.20±1.31 / 76.19±2.23     150               2.35×10⁴
                      2           79.83±2.35 / 72.56±1.90 / 71.39±2.36
         FedPer       1           84.38±4.58 / 78.73±4.59 / 76.21±4.27     130               2.25×10⁴
                      2           84.51±4.39 / 73.31±4.76 / 72.43±4.55
         FedAvg       1           81.72±2.77 / 76.77±2.37 / 75.74±2.61     150               2.35×10⁴
                      2           78.99±2.34 / 72.73±2.58 / 70.93±2.82
         FedRep       1           81.44±10.48 / 76.93±7.46 / 73.36±7.04    110               2.25×10⁴
                      2           76.70±11.79 / 73.54±11.42 / 70.30±8.00
         FedProto     1           84.49±1.97 / 79.12±2.03 / 77.08±1.98     110               4.10
                      2           81.75±1.39 / 74.98±1.61 / 71.17±1.29
         FedProto-mh  1           83.63±1.60 / 79.49±1.78 / 76.94±1.33     110               4.10
                      2           79.90±1.08 / 75.78±1.05 / 72.67±1.09
Table 4: Comparison of FL methods on CIFAR10 with non-IID split over clients. Best results are in bold. FedProto achieves higher accuracy than the baseline methods while using far fewer communicated parameters.

Convergence Analysis for FedProto

Additional Notation

Here, additional variables are introduced to better represent the local model update. Let $f_i(\phi_i): \mathbb{R}^{d_x} \rightarrow \mathbb{R}^{d_c}$ be the embedding function of the $i$-th client, which may differ across clients; $d_x$ and $d_c$ denote the dimensions of the input $x$ and of the prototype, respectively, and are the same for all clients. $g_i(\nu_i): \mathbb{R}^{d_c} \rightarrow \mathbb{R}^{d_y}$ is the decision function, where $d_y$ is the dimension of the output $y$. The labeling function can thus be written as $\mathcal{F}_i(\phi_i,\nu_i)=g_i(\nu_i)\circ f_i(\phi_i)$, and we sometimes write $\omega_i=(\phi_i,\nu_i)$ for short. In the theoretical analysis, we omit the class label $(j)$ of the prototype $C^{(j)}$ for convenience, which does not affect the proof. We use $q_i$ to denote the weight of the prototype of the $i$-th client, and $p_i$ the weight of the loss function of the $i$-th client.

Therefore, the local loss function of client $i$ can be written as:

\mathcal{L}(\phi_i,\nu_i;x,y)=\mathcal{L}_S(\mathcal{F}_i(\phi_i,\nu_i;x),y)+\lambda\,\|f_i(\phi_i;x)-\bar{C}\|_2^2,   (1)

in which the global prototype

\bar{C}=\sum_{i=1}^{m}q_i C_i   (2)

with

\sum_{i=1}^{m}q_i=\sum_{i=1}^{m}\frac{|D_i|}{N}=1   (3)

and

C_i=\frac{1}{|D_i|}\sum_{(x,y)\in D_i}f_i(\phi_i;x),   (4)

and $\bar{C}$ is a constant in $\mathcal{L}$ within each communication round but changes between rounds, which complicates the convergence analysis.
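Equations (1)-(4) translate directly into code. The sketch below is an illustrative pure-Python rendering (not the authors' implementation): the server forms the global prototype as a data-size-weighted average of the local prototypes, and each client adds the λ-weighted squared distance to the global prototype to its supervised loss.

```python
def weighted_global_prototype(local_protos, data_sizes):
    """Eqs. (2)-(3): C_bar = sum_i q_i * C_i with q_i = |D_i| / N."""
    total = sum(data_sizes)
    dim = len(local_protos[0])
    c_bar = [0.0] * dim
    for proto, size in zip(local_protos, data_sizes):
        q = size / total                 # weight q_i of client i
        for d in range(dim):
            c_bar[d] += q * proto[d]
    return c_bar

def fedproto_local_loss(supervised_loss, embedding, c_bar, lam):
    """Eq. (1): supervised loss plus lambda * ||f_i(x) - C_bar||_2^2."""
    dist_sq = sum((e - c) ** 2 for e, c in zip(embedding, c_bar))
    return supervised_loss + lam * dist_sq

# Toy example: two clients with 2-D prototypes and |D_1| = 3, |D_2| = 1.
c_bar = weighted_global_prototype([[1.0, 0.0], [0.0, 1.0]], [3, 1])
# q = (0.75, 0.25) -> c_bar = [0.75, 0.25]
loss = fedproto_local_loss(0.5, [1.0, 0.0], c_bar, lam=1.0)
# ||(1,0) - (0.75,0.25)||^2 = 0.0625 + 0.0625 = 0.125 -> loss = 0.625
```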

As for the iteration notation, we use $t$ to denote the communication round and $e \in \{1/2, 1, 2, \dots, E\}$ to denote the local iteration. There are $E$ local iterations in total, so $tE+e$ refers to the $e$-th local iteration in communication round $t+1$. Moreover, $tE$ denotes the time step before prototype aggregation at the server, and $tE+1/2$ the time step between prototype aggregation at the server and the first iteration on the local model.

Assumptions

Assumption 1.

(Lipschitz Smooth). Each local objective function is $L_1$-Lipschitz smooth, which also means that the gradient of the local objective function is $L_1$-Lipschitz continuous,

\|\nabla\mathcal{L}_{t_1}-\nabla\mathcal{L}_{t_2}\|_2\leq L_1\|\omega_{i,t_1}-\omega_{i,t_2}\|_2,\quad\forall t_1,t_2>0,\ i\in\{1,2,\dots,m\},   (5)

which implies the following quadratic bound,

\mathcal{L}_{t_1}-\mathcal{L}_{t_2}\leq\langle\nabla\mathcal{L}_{t_2},\,\omega_{i,t_1}-\omega_{i,t_2}\rangle+\frac{L_1}{2}\|\omega_{i,t_1}-\omega_{i,t_2}\|_2^2,\quad\forall t_1,t_2>0,\ i\in\{1,2,\dots,m\}.   (6)
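For completeness, the quadratic bound (6) follows from the smoothness condition (5) by the standard integral argument (the fundamental theorem of calculus along the segment between the two iterates, then Cauchy-Schwarz and (5)):

```latex
\begin{aligned}
\mathcal{L}_{t_1}-\mathcal{L}_{t_2}
 &= \int_0^1 \big\langle \nabla\mathcal{L}\big(\omega_{i,t_2}+\tau(\omega_{i,t_1}-\omega_{i,t_2})\big),\,
    \omega_{i,t_1}-\omega_{i,t_2}\big\rangle \, d\tau \\
 &= \langle \nabla\mathcal{L}_{t_2},\,\omega_{i,t_1}-\omega_{i,t_2}\rangle
  + \int_0^1 \big\langle \nabla\mathcal{L}\big(\omega_{i,t_2}+\tau(\omega_{i,t_1}-\omega_{i,t_2})\big)
    -\nabla\mathcal{L}_{t_2},\,\omega_{i,t_1}-\omega_{i,t_2}\big\rangle \, d\tau \\
 &\leq \langle \nabla\mathcal{L}_{t_2},\,\omega_{i,t_1}-\omega_{i,t_2}\rangle
  + \int_0^1 L_1 \tau\,\|\omega_{i,t_1}-\omega_{i,t_2}\|_2^2 \, d\tau
  = \langle \nabla\mathcal{L}_{t_2},\,\omega_{i,t_1}-\omega_{i,t_2}\rangle
  + \frac{L_1}{2}\|\omega_{i,t_1}-\omega_{i,t_2}\|_2^2 .
\end{aligned}
```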
Assumption 2.

(Unbiased Gradient and Bounded Variance). The stochastic gradient $g_{i,t}=\nabla\mathcal{L}(\omega_{i,t},\xi_t)$ is an unbiased estimator of the local gradient for each client, that is,

\mathbb{E}_{\xi_i\sim D_i}[g_{i,t}]=\nabla\mathcal{L}(\omega_{i,t})=\nabla\mathcal{L}_t,\quad\forall i\in\{1,2,\dots,m\}, \qquad (7)

and its variance is bounded by $\sigma^2$:

\mathbb{E}[\|g_{i,t}-\nabla\mathcal{L}(\omega_{i,t})\|_2^2]\leq\sigma^2,\quad\forall i\in\{1,2,\dots,m\},\ \sigma^2\geq 0. \qquad (8)
Assumption 3.

(Bounded Expectation of Euclidean Norm of Stochastic Gradients). The expectation of the stochastic gradient norm is bounded by $G$:

\mathbb{E}[\|g_{i,t}\|_2]\leq G,\quad\forall i\in\{1,2,\dots,m\}. \qquad (9)
Assumption 4.

(Lipschitz Continuity). Each local embedding function is $L_2$-Lipschitz continuous, that is,

\|f_i(\phi_{i,t_1})-f_i(\phi_{i,t_2})\|\leq L_2\|\phi_{i,t_1}-\phi_{i,t_2}\|_2,\quad\forall t_1,t_2>0,\ i\in\{1,2,\dots,m\}. \qquad (10)

Assumption 4 is relatively strong, but we only invoke it over a narrow domain spanning $E$ steps of SGD in Lemma 2.
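To make the smoothness assumption concrete, the following hedged numeric check uses the toy loss $\mathcal{L}(\omega)=\frac{1}{2}\|\omega\|_2^2$, whose gradient is $\omega$ itself and hence $1$-Lipschitz, so the quadratic bound of Eq. (6) holds with $L_1=1$ (for this particular loss, with equality). The values are arbitrary.

```python
# Numeric illustration of Assumption 1 / Eq. (6) for L(w) = 0.5 * ||w||^2.

def loss(w):
    return 0.5 * sum(x * x for x in w)

def grad(w):
    return list(w)  # gradient of 0.5*||w||^2 is w, which is 1-Lipschitz

L1 = 1.0
w1, w2 = [3.0, -1.0], [0.5, 2.0]
diff = [a - b for a, b in zip(w1, w2)]
lhs = loss(w1) - loss(w2)
rhs = sum(g * d for g, d in zip(grad(w2), diff)) + L1 / 2 * sum(d * d for d in diff)
# Eq. (6): lhs <= rhs (here, equality, since the toy loss is exactly quadratic)
```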

Key Lemmas

Lemma 1.

Let Assumptions 1 and 2 hold. From the beginning of communication round $t+1$ to the last local update step, the loss function of an arbitrary client can be bounded as:

\mathbb{E}[\mathcal{L}_{(t+1)E}]\leq\mathcal{L}_{tE+1/2}-\Big(\eta-\frac{L_1\eta^2}{2}\Big)\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2+\frac{L_1E\eta^2}{2}\sigma^2. \qquad (11)
Proof.

Since this lemma holds for an arbitrary client, the client index $i$ is omitted. Let $\omega_{t+1}=\omega_t-\eta g_t$. Then

\begin{aligned}
\mathcal{L}_{tE+1}&\stackrel{(a)}{\leq}\mathcal{L}_{tE+1/2}+\langle\nabla\mathcal{L}_{tE+1/2},\,\omega_{tE+1}-\omega_{tE+1/2}\rangle+\frac{L_1}{2}\|\omega_{tE+1}-\omega_{tE+1/2}\|_2^2\\
&=\mathcal{L}_{tE+1/2}-\eta\langle\nabla\mathcal{L}_{tE+1/2},\,g_{tE+1/2}\rangle+\frac{L_1}{2}\|\eta g_{tE+1/2}\|_2^2, \qquad (12)
\end{aligned}

where (a) follows from the quadratic $L_1$-Lipschitz smoothness bound in Assumption 1. Taking the expectation of both sides with respect to the random variable $\xi_{tE+1/2}$, we have

\begin{aligned}
\mathbb{E}[\mathcal{L}_{tE+1}]&\leq\mathcal{L}_{tE+1/2}-\eta\,\mathbb{E}[\langle\nabla\mathcal{L}_{tE+1/2},\,g_{tE+1/2}\rangle]+\frac{L_1\eta^2}{2}\mathbb{E}[\|g_{tE+1/2}\|_2^2] & (13)\\
&\stackrel{(b)}{=}\mathcal{L}_{tE+1/2}-\eta\|\nabla\mathcal{L}_{tE+1/2}\|_2^2+\frac{L_1\eta^2}{2}\mathbb{E}[\|g_{tE+1/2}\|_2^2] & (14)\\
&\stackrel{(c)}{\leq}\mathcal{L}_{tE+1/2}-\eta\|\nabla\mathcal{L}_{tE+1/2}\|_2^2+\frac{L_1\eta^2}{2}\big(\|\nabla\mathcal{L}_{tE+1/2}\|_2^2+\mathrm{Var}(g_{tE+1/2})\big) & (15)\\
&=\mathcal{L}_{tE+1/2}-\Big(\eta-\frac{L_1\eta^2}{2}\Big)\|\nabla\mathcal{L}_{tE+1/2}\|_2^2+\frac{L_1\eta^2}{2}\mathrm{Var}(g_{tE+1/2}) & (16)\\
&\stackrel{(d)}{\leq}\mathcal{L}_{tE+1/2}-\Big(\eta-\frac{L_1\eta^2}{2}\Big)\|\nabla\mathcal{L}_{tE+1/2}\|_2^2+\frac{L_1\eta^2}{2}\sigma^2, & (17)
\end{aligned}

where (b) follows from the unbiasedness in Assumption 2, (c) follows from $\mathrm{Var}(x)=\mathbb{E}[x^2]-(\mathbb{E}[x])^2$, and (d) follows from the bounded variance in Assumption 2. Taking the expectation over $\omega$ on both sides and telescoping over the $E$ local steps, we have

\mathbb{E}[\mathcal{L}_{(t+1)E}]\leq\mathcal{L}_{tE+1/2}-\Big(\eta-\frac{L_1\eta^2}{2}\Big)\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2+\frac{L_1E\eta^2}{2}\sigma^2. \qquad (18)

∎
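A hedged sanity check of the descent bound above in the noise-free case ($\sigma=0$): plain gradient descent on the toy loss $\mathcal{L}(\omega)=\frac{1}{2}\|\omega\|_2^2$ (so $L_1=1$ and $g$ is the exact gradient) satisfies Eq. (18) over $E$ steps. The constants are made up for illustration, and for this toy loss the bound is tight.

```python
# Noise-free check of Eq. (18): run E gradient steps on L(w) = 0.5*||w||^2
# and compare the final loss with the telescoped descent bound.

def loss(w):
    return 0.5 * sum(x * x for x in w)

eta, E, L1 = 0.1, 5, 1.0
w = [3.0, -4.0]
start = loss(w)
grad_sq_sum = 0.0
for _ in range(E):                                # E local update steps
    g = list(w)                                   # exact gradient of the toy loss
    grad_sq_sum += sum(x * x for x in g)          # accumulates ||grad||_2^2
    w = [wi - eta * gi for wi, gi in zip(w, g)]
final = loss(w)
# Eq. (18) with sigma = 0: final <= start - (eta - L1*eta^2/2) * grad_sq_sum
```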

Lemma 2.

Let Assumptions 3 and 4 hold. After the prototype aggregation at the server, the loss function of an arbitrary client can be bounded as:

\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]\leq\mathcal{L}_{(t+1)E}+\lambda L_2\eta EG. \qquad (19)
Proof.
\begin{aligned}
\mathcal{L}_{(t+1)E+1/2}&=\mathcal{L}_{(t+1)E}+\mathcal{L}_{(t+1)E+1/2}-\mathcal{L}_{(t+1)E} & (20)\\
&\stackrel{(a)}{=}\mathcal{L}_{(t+1)E}+\lambda\|f_i(\phi_{i,(t+1)E})-\bar{C}_{t+2}\|_2-\lambda\|f_i(\phi_{i,(t+1)E})-\bar{C}_{t+1}\|_2 & (21)\\
&\stackrel{(b)}{\leq}\mathcal{L}_{(t+1)E}+\lambda\|\bar{C}_{t+2}-\bar{C}_{t+1}\|_2 & (22)\\
&\stackrel{(c)}{=}\mathcal{L}_{(t+1)E}+\lambda\Big\|\sum_{i=1}^m q_iC_{i,(t+1)E}-\sum_{i=1}^m q_iC_{i,tE}\Big\|_2 & (23)\\
&=\mathcal{L}_{(t+1)E}+\lambda\Big\|\sum_{i=1}^m q_i\big(C_{i,(t+1)E}-C_{i,tE}\big)\Big\|_2 & (24)\\
&\stackrel{(d)}{=}\mathcal{L}_{(t+1)E}+\lambda\Big\|\sum_{i=1}^m q_i\frac{1}{|D_i|}\sum_{k=1}^{|D_i|}\big(f_i(\phi_{i,(t+1)E};x_{i,k})-f_i(\phi_{i,tE};x_{i,k})\big)\Big\|_2 & (25)\\
&\stackrel{(e)}{\leq}\mathcal{L}_{(t+1)E}+\lambda\sum_{i=1}^m\frac{q_i}{|D_i|}\sum_{k=1}^{|D_i|}\|f_i(\phi_{i,(t+1)E};x_{i,k})-f_i(\phi_{i,tE};x_{i,k})\|_2 & (26)\\
&\stackrel{(f)}{\leq}\mathcal{L}_{(t+1)E}+\lambda L_2\sum_{i=1}^m q_i\|\phi_{i,(t+1)E}-\phi_{i,tE}\|_2 & (27)\\
&\stackrel{(g)}{\leq}\mathcal{L}_{(t+1)E}+\lambda L_2\sum_{i=1}^m q_i\|\omega_{i,(t+1)E}-\omega_{i,tE}\|_2 & (28)\\
&=\mathcal{L}_{(t+1)E}+\lambda L_2\eta\sum_{i=1}^m q_i\Big\|\sum_{e=1/2}^{E-1}g_{i,tE+e}\Big\|_2 & (29)\\
&\stackrel{(h)}{\leq}\mathcal{L}_{(t+1)E}+\lambda L_2\eta\sum_{i=1}^m q_i\sum_{e=1/2}^{E-1}\|g_{i,tE+e}\|_2. & (30)
\end{aligned}

Taking the expectation with respect to the random variable $\xi$ on both sides, we have

\begin{aligned}
\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]&\leq\mathcal{L}_{(t+1)E}+\lambda L_2\eta\sum_{i=1}^m q_i\sum_{e=1/2}^{E-1}\mathbb{E}[\|g_{i,tE+e}\|_2] & (32)\\
&\stackrel{(i)}{\leq}\mathcal{L}_{(t+1)E}+\lambda L_2\eta EG, & (33)
\end{aligned}

where (a) follows from the definition of the local loss function in Eq. 1, (b) follows from $\|a-b\|_2-\|a-c\|_2\leq\|b-c\|_2$, (c) follows from the definition of the global prototype in Eq. 2, (d) follows from the definition of the local prototype in Eq. 4, (e) and (h) follow from $\|\sum_i a_i\|_2\leq\sum_i\|a_i\|_2$, (f) follows from the $L_2$-Lipschitz continuity in Assumption 4, (g) follows from the fact that $\phi_i$ is a subset of $\omega_i$, and (i) follows from Assumption 3. ∎
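The key inequality chain (b)-(e) above bounds the drift of the aggregated global prototype by the weighted sum of local prototype drifts. The following hedged numeric check illustrates that step with arbitrary made-up weights and prototypes.

```python
# Check ||sum_i q_i (C_i_new - C_i_old)||_2 <= sum_i q_i ||C_i_new - C_i_old||_2,
# i.e. the triangle-inequality step used between Eqs. (24) and (26).

def norm(v):
    return sum(x * x for x in v) ** 0.5

q = [0.5, 0.3, 0.2]                               # aggregation weights, sum to 1
C_old = [[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0]]     # local prototypes C_{i,tE}
C_new = [[0.0, 1.0], [2.0, 2.0], [-1.0, -1.0]]    # local prototypes C_{i,(t+1)E}

drift_vec = [sum(q[i] * (C_new[i][d] - C_old[i][d]) for i in range(3))
             for d in range(2)]
global_drift = norm(drift_vec)                    # drift of the global prototype
local_drift = sum(q[i] * norm([a - b for a, b in zip(C_new[i], C_old[i])])
                  for i in range(3))              # weighted local drifts
assert global_drift <= local_drift                # triangle inequality
```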

Theorems

Theorem 1.

(One-round deviation). Let Assumptions 1 to 4 hold. For an arbitrary client, after every communication round, we have

\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]\leq\mathcal{L}_{tE+1/2}-\Big(\eta-\frac{L_1\eta^2}{2}\Big)\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2+\frac{L_1E\eta^2}{2}\sigma^2+\lambda L_2\eta EG. \qquad (34)
Corollary 1.

(Non-convex FedProto convergence). The loss function $\mathcal{L}$ of an arbitrary client decreases monotonically in every communication round when

\eta_{e'}<\frac{2\big(\sum_{e=1/2}^{e'}\|\nabla\mathcal{L}_{tE+e}\|_2^2-\lambda L_2EG\big)}{L_1\big(\sum_{e=1/2}^{e'}\|\nabla\mathcal{L}_{tE+e}\|_2^2+E\sigma^2\big)},\quad e'=1/2,1,\dots,E-1, \qquad (35)

and

\lambda_t<\frac{\|\nabla\mathcal{L}_{tE+1/2}\|_2^2}{L_2EG}. \qquad (36)

Thus, the loss function converges.

Theorem 2.

(Non-convex convergence rate of FedProto). Let Assumptions 1 to 4 hold and let $\Delta=\mathcal{L}_0-\mathcal{L}^*$. For an arbitrary client, given any $\epsilon>0$, after

T=\frac{2\Delta}{E\epsilon(2\eta-L_1\eta^2)-E\eta(L_1\eta\sigma^2+2\lambda L_2G)} \qquad (37)

communication rounds of FedProto, we have

\frac{1}{TE}\sum_{t=0}^{T-1}\sum_{e=1/2}^{E-1}\mathbb{E}[\|\nabla\mathcal{L}_{tE+e}\|_2^2]<\epsilon, \qquad (38)

if

\eta<\frac{2(\epsilon-\lambda L_2G)}{L_1(\epsilon+\sigma^2)}, \qquad (39)

and

\lambda<\frac{\epsilon}{L_2G}. \qquad (40)
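As a hedged numeric illustration of Theorem 2, the following plugs made-up constants into Eqs. (37), (39) and (40): when $\eta$ and $\lambda$ satisfy the two conditions, the denominator of Eq. (37) is positive and the required number of rounds $T$ is finite. None of these values come from the paper's experiments.

```python
# Illustrative constants; lam and eta are chosen to satisfy Eqs. (40) and (39).
L1, L2, G, sigma2 = 1.0, 1.0, 1.0, 1.0
eps, Delta, E = 0.5, 10.0, 5
lam = 0.2   # Eq. (40): lam < eps / (L2*G) = 0.5
eta = 0.2   # Eq. (39): eta < 2*(eps - lam*L2*G) / (L1*(eps + sigma2)) = 0.4

# Denominator of Eq. (37); positive under the conditions above.
denom = (E * eps * (2 * eta - L1 * eta**2)
         - E * eta * (L1 * eta * sigma2 + 2 * lam * L2 * G))
T = 2 * Delta / denom   # Eq. (37): required communication rounds
```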

Completing the Proof of Theorem 1 and Corollary 1

Proof.

Taking the expectation over $\omega$ on both sides in Lemmas 1 and 2 and summing the two bounds, we obtain

\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]\leq\mathcal{L}_{tE+1/2}-\Big(\eta-\frac{L_1\eta^2}{2}\Big)\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2+\frac{L_1E\eta^2}{2}\sigma^2+\lambda L_2\eta EG. \qquad (41)

Then, to guarantee that $-\big(\eta-\frac{L_1\eta^2}{2}\big)\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2+\frac{L_1E\eta^2}{2}\sigma^2+\lambda L_2\eta EG\leq 0$, we need

\eta<\frac{2\big(\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2-\lambda L_2EG\big)}{L_1\big(\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2+E\sigma^2\big)}, \qquad (42)

and

\lambda<\frac{\sum_{e=1/2}^{E-1}\|\nabla\mathcal{L}_{tE+e}\|_2^2}{L_2EG}. \qquad (43)

In practice, we use

\eta_{e'}<\frac{2\big(\sum_{e=1/2}^{e'}\|\nabla\mathcal{L}_{tE+e}\|_2^2-\lambda L_2EG\big)}{L_1\big(\sum_{e=1/2}^{e'}\|\nabla\mathcal{L}_{tE+e}\|_2^2+E\sigma^2\big)},\quad e'=1/2,1,\dots,E-1, \qquad (44)

and

\lambda_t<\frac{\|\nabla\mathcal{L}_{tE+1/2}\|_2^2}{L_2EG}. \qquad (45)

So, the convergence of $\mathcal{L}$ holds. ∎

Completing the Proof of Theorem 2

Proof.

Taking the expectation over $\omega$ on both sides of Eq. 34, then telescoping over communication rounds $t=0$ to $t=T-1$, with the time step running from $e=1/2$ to $e=E-1$ in each communication round, we have

\frac{1}{TE}\sum_{t=0}^{T-1}\sum_{e=1/2}^{E-1}\mathbb{E}[\|\nabla\mathcal{L}_{tE+e}\|_2^2]\leq\frac{\frac{1}{TE}\sum_{t=0}^{T-1}\big(\mathcal{L}_{tE+1/2}-\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]\big)+\frac{L_1\eta^2}{2}\sigma^2+\lambda L_2\eta G}{\eta-\frac{L_1\eta^2}{2}}. \qquad (46)

Given any ϵ>0italic-ϵ0\epsilon>0, let

\frac{\frac{1}{TE}\sum_{t=0}^{T-1}\big(\mathcal{L}_{tE+1/2}-\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]\big)+\frac{L_1\eta^2}{2}\sigma^2+\lambda L_2\eta G}{\eta-\frac{L_1\eta^2}{2}}<\epsilon, \qquad (47)

that is

\frac{\frac{2}{TE}\sum_{t=0}^{T-1}\big(\mathcal{L}_{tE+1/2}-\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]\big)+L_1\eta^2\sigma^2+2\lambda L_2\eta G}{2\eta-L_1\eta^2}<\epsilon. \qquad (48)

Let $\Delta=\mathcal{L}_0-\mathcal{L}^*$. Since $\sum_{t=0}^{T-1}\big(\mathcal{L}_{tE+1/2}-\mathbb{E}[\mathcal{L}_{(t+1)E+1/2}]\big)\leq\Delta$, the above inequality holds when

\frac{\frac{2\Delta}{TE}+L_1\eta^2\sigma^2+2\lambda L_2\eta G}{2\eta-L_1\eta^2}<\epsilon, \qquad (49)

that is

T>\frac{2\Delta}{E\epsilon(2\eta-L_1\eta^2)-E\eta(L_1\eta\sigma^2+2\lambda L_2G)}. \qquad (50)

So, we have

\frac{1}{TE}\sum_{t=0}^{T-1}\sum_{e=1/2}^{E-1}\mathbb{E}[\|\nabla\mathcal{L}_{tE+e}\|_2^2]<\epsilon, \qquad (51)

when

\eta<\frac{2(\epsilon-\lambda L_2G)}{L_1(\epsilon+\sigma^2)}, \qquad (52)

and

\lambda<\frac{\epsilon}{L_2G}. \qquad (53)

∎
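As a hedged numeric check of the rearrangement from Eq. (49) to Eq. (50) with made-up constants: at the round count $T$ given by the bound in Eq. (50), the left-hand side of Eq. (49) equals $\epsilon$ exactly, and it drops below $\epsilon$ for any larger $T$.

```python
# Verify that Eq. (49) and Eq. (50) are the same condition rearranged.
Delta, E, eta = 10.0, 5, 0.2
L1, L2, G, sigma2, lam, eps = 1.0, 1.0, 1.0, 1.0, 0.2, 0.5

def lhs_49(T):
    """Left-hand side of Eq. (49) as a function of the round count T."""
    num = 2 * Delta / (T * E) + L1 * eta**2 * sigma2 + 2 * lam * L2 * eta * G
    return num / (2 * eta - L1 * eta**2)

# Bound of Eq. (50): the threshold value of T.
T_bound = 2 * Delta / (E * eps * (2 * eta - L1 * eta**2)
                       - E * eta * (L1 * eta * sigma2 + 2 * lam * L2 * G))
```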